//InstDecode.scala
package test

import chisel3._
import chisel3.util._

/** Instruction-decode (ID) stage.
  *
  * Latches the IF/ID bundle, reads the register file, decodes the instruction,
  * bypasses in-flight results from the `RD_ex` / `RD_mem` / `RD_wb` ports into
  * the source operands, and resolves branches/jumps in this stage (`io.BJ`).
  */
class InstDecode extends Module{
  val io = IO(new Bundle{
    val ifid = Input(new IFID)
    val idex = Output(new IDEX)

    val wbid = Input(new WBID)

    val BJ = Output(new BranchJump)
    val RD_ex = Input(new ReDirection)
    val RD_mem = Input(new ReDirection)
    val RD_wb = Input(new ReDirection)

    val Bubble_id = Output(Bool())
    val enable = Input(Bool())

    // Debug view of the lab1 vector dot-product result kept in x31:
    // each bit of `select` enables one 4-bit nibble of x31 onto `results`.
    val select = Input(UInt(4.W))
    val results = Output(UInt(4.W))
  })

  // ids means "InstDecodeStream": the IF/ID bundle, latched only while io.enable is high.
  val ids = RegEnable(io.ifid, io.enable)

  // ---- Instruction fields ---------------------------------------------------
  val inst = ids.Inst
  val I_imm = inst(31,20)
  val U_imm = inst(31,12)   // not used within this module
  val shamt = inst(24,20)
  val rs1 = inst(19,15)
  val rs2 = inst(24,20)
  val regdes = inst(11,7)

  // ---- Register file / decoder ----------------------------------------------
  val regfile = new RegFile
  val src1 = regfile.read(rs1)
  val src2 = regfile.read(rs2)
  regfile.write(io.wbid.wen, io.wbid.regdes, io.wbid.wbdata)
  val simm = SignExt(I_imm, 32)
  val decoder = new Decoder(inst)
  val wen = decoder.wen
  val aluop = decoder.aluop

  // ---- Lab1 vector dot-product debug output ---------------------------------
  // Each select bit is replicated to a 4-bit mask that gates one nibble of x31;
  // the gated nibbles are ORed together onto io.results.
  val reg31 = regfile.read(31.U(5.W))
  val reg31_3_0 = SignExt(io.select(0), 4) & reg31(3,0)
  val reg31_7_4 = SignExt(io.select(1), 4) & reg31(7,4)
  val reg31_11_8 = SignExt(io.select(2), 4) & reg31(11,8)
  val reg31_15_12 = SignExt(io.select(3), 4) & reg31(15,12)
  io.results := reg31_3_0 | reg31_7_4 | reg31_11_8 | reg31_15_12

  // ---- Forwarding -----------------------------------------------------------
  // A producer "hits" a source register when it will write that register.
  // x0 is excluded: in RISC-V it is hardwired to zero, so an in-flight write
  // with rd = x0 must never be bypassed (nor cause a stall); the register-file
  // read stays authoritative for x0.
  private def hits(rd: ReDirection, rs: UInt): Bool =
    rd.wen && (rd.regdes =/= 0.U) && (rs === rd.regdes)

  val RD_src1_hitEX  = hits(io.RD_ex,  rs1)
  val RD_src1_hitMEM = hits(io.RD_mem, rs1)
  val RD_src1_hitWB  = hits(io.RD_wb,  rs1)

  val RD_src2_hitEX  = hits(io.RD_ex,  rs2)
  val RD_src2_hitMEM = hits(io.RD_mem, rs2)
  val RD_src2_hitWB  = hits(io.RD_wb,  rs2)

  // NOTE(review): ID stalls on any EX hit even though the EX value is also
  // bypassed below — presumably because RD_ex.wbdata is not final for some
  // instructions (e.g. loads); confirm. rs2 is checked for every instruction,
  // including ones that do not read rs2, which can cause extra stalls.
  io.Bubble_id := RD_src1_hitEX || RD_src2_hitEX

  // Youngest producer wins: EX over MEM over WB over the register file.
  val src1_fwd = Mux(RD_src1_hitEX,  io.RD_ex.wbdata,
                 Mux(RD_src1_hitMEM, io.RD_mem.wbdata,
                 Mux(RD_src1_hitWB,  io.RD_wb.wbdata, src1)))
  val src2_fwd = Mux(RD_src2_hitEX,  io.RD_ex.wbdata,
                 Mux(RD_src2_hitMEM, io.RD_mem.wbdata,
                 Mux(RD_src2_hitWB,  io.RD_wb.wbdata, src2)))

  // ---- ID/EX outputs --------------------------------------------------------
  // While io.enable is low, Inst/PC/aluop are zeroed and wen is cleared,
  // presumably to send a bubble to EX instead of a duplicate instruction.
  // NOTE(review): Valid is not gated by io.enable, unlike the fields below — confirm intended.
  io.idex.Valid := ids.Valid
  io.idex.Inst  := Mux(io.enable, ids.Inst, 0.U(32.W))
  io.idex.PC    := Mux(io.enable, ids.PC, 0.U(32.W))
  io.idex.aluop := Mux(io.enable, aluop, 0.U(32.W))
  io.idex.src1  := src1_fwd
  io.idex.src2  := src2_fwd
  io.idex.simm  := simm
  io.idex.shamt := shamt
  io.idex.wen   := io.enable && wen
  io.idex.regdes:= regdes

  // ---- Branch / jump resolution ---------------------------------------------
  // B- and J-type immediates are scattered across the encoding; reassemble
  // them (bit 0 is implicitly zero) and sign-extend to 32 bits.
  val branch_offset = SignExt(Cat(inst(31), inst(7), inst(30,25), inst(11,8), 0.U(1.W)), 32)
  val jal_offset    = SignExt(Cat(inst(31), inst(19,12), inst(20), inst(30,21), 0.U(1.W)), 32)
  // JALR uses the plain sign-extended I-type immediate.
  val jalr_offset   = simm
  val branch_target = ids.PC + branch_offset
  val jal_target    = ids.PC + jal_offset
  // RISC-V: JALR target = (rs1 + imm) & ~1. The LSB is cleared on the sum,
  // not on the immediate, so an odd rs1 still yields an even target.
  val jalr_target   = (src1_fwd + jalr_offset) & ~(1.U(32.W))

  val branch  = decoder.BEQ  && (src1_fwd === src2_fwd) ||
                decoder.BNE  && (src1_fwd =/= src2_fwd) ||
                decoder.BGEU && (src1_fwd >= src2_fwd) ||
                decoder.BLTU && (src1_fwd <  src2_fwd) ||
                decoder.BGE  && (src1_fwd.asSInt >= src2_fwd.asSInt) ||
                decoder.BLT  && (src1_fwd.asSInt <  src2_fwd.asSInt)
  val jump    = decoder.JAL || decoder.JALR
  val vld     = branch || jump
  // OR of gated targets acts as a one-hot select; this assumes the decoder's
  // branch/JAL/JALR flags are mutually exclusive — TODO confirm.
  val target  = Mux(branch, branch_target, 0.U(32.W)) |
                Mux(decoder.JAL, jal_target, 0.U(32.W)) |
                Mux(decoder.JALR, jalr_target, 0.U(32.W))
  io.BJ.vld    := vld
  io.BJ.target := target

}

